%% Calculate P_hat

function [P_hat_mnj,b_hat,b_hat_f,P_M_hat,P_M_f_hat,mu_hat_mnjf] = PriceLoopCES_fun_approx_Olig(P_hat_mnj0,w_hat,pi_mnj,...
       pi_M, pi_l,pi_mnj_sorted,pi_M_f,pi_l_f,mu_hat_mnjf0)


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Author: Produced for JdG,AL,IM by Christopher Evans at UPF
%
% Solves for the relative change in prices (P_hat). Starting from a wage
% change w_hat and an initial price-change guess, the function iterates:
% price changes -> input-bundle cost changes (b_hat) -> updated price
% changes, until the relative deviation between successive guesses is
% below tol or maxit iterations have been run. French country-sector
% prices are rebuilt from firm-level data, including firm-level markup
% changes (mu_hat).
%
% NEW: subscripts follow master_notes3.pdf and the draft, i.e. {mn,ij}
% pairs, where {mn} is a country pair and {ij} a sector pair. Global
% variables were renamed and the file structure for loading data was
% modified slightly. The gamma_type variable was removed; files previously
% saved as *_2.csv were adapted to the new directory (original note was
% incomplete here).
%
% MAJOR CHANGE: the price index is based on a CES cost function, which
% changes the functional form of the price computation. See
% master_notes3 for the formula.
%
% The function is used both for the baseline setting and for the
% heterogeneous calibration, so the initial labour and intermediate import
% shares may differ from the data (alpha, gamma) in the initial wedge
% exercise. That is why it takes pi_l and pi_M rather than data shares.
%
% b_hat (cost function change) depends on w_hat and P^M_hat, which in turn
% depends on P_mnj_hat. Given the formulae, labour and intermediate-good
% shares do not need updating here while solving for the bilateral price
% change. b_hat is also returned for use in later loops.
%
% Gammas come from a single large matrix saved in Step2 rather than from
% .(ISO{}) cell lookups, to save on loops and speed up the procedure.
%
% Inputs (shapes inferred from the operations below):
%   P_hat_mnj0    - (N*J x N) initial guess of the price change
%   w_hat         - wage changes by country (assumed 1 x N -- TODO confirm)
%   pi_mnj        - (N*J x N) bilateral shares used in the sector price
%   pi_M          - (N*J x N*J) intermediate import shares; summed over
%                   rows (source country-sectors) to build P^M_hat
%   pi_l          - (N*J x 1) labour shares of country-sectors
%   pi_mnj_sorted - (FI_J x N) initial French firm market shares
%   pi_M_f        - (FI_J x N*J) French firms' intermediate import shares
%   pi_l_f        - (FI_J x 1) French firms' labour shares
%   mu_hat_mnjf0  - (FI_J x N) initial guess of the markup change
%
% Outputs:
%   P_hat_mnj   - (N*J x N) price change at the last iteration
%   b_hat       - (N*J x 1) change in country-sector input-bundle cost
%   b_hat_f     - (FI_J x 1) change in French firms' input-bundle cost
%   P_M_hat     - (1 x N*J) intermediate price index change (Inf -> 1)
%   P_M_f_hat   - (FI_J x 1) French firms' intermediate price index change
%   mu_hat_mnjf - (FI_J x N) markup change, updated once more at the
%                 final prices
%
% Issues a warning (identifier PriceLoopCES_fun_approx_Olig:*) when the
% loop stops without meeting the tolerance.
%
% Last Updated: 5/02/2019 by IM & JDG
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

global alphaT rhoT lambdaT etaT rho sigma lambda eta varphi ...
    tol maxit Pi_hat_n D_hat_n Xi_hat_mnjf Xi_hat_mnj a_hat_f a_hat ...
    N J FI_J france countrysecD firmsecD_sorted start start_sorted ...
    sum_by_country_dummy  ...
    sL_n sPi_n sD_n alpha_WIOD_j alpha_WIOD_francej labour_share_PWT

pfmax = 1;
it    = 1;

% Initialise outcome variables to the initial values fed into the function
P_hat_knj   = P_hat_mnj0;
pi_mnjf0    = pi_mnj_sorted;
mu_hat_mnjf = mu_hat_mnjf0;

% ---- Loop-invariant terms: depend only on inputs and globals ----

% Labour-cost part of log(b_hat), pi_l*log(w_hat), mapped to country-sectors
log_w_hat           = log(w_hat);
log_w_hat_wlambda   = pi_l.*(countrysecD*log_w_hat');
log_w_hat_f_wlambda = pi_l_f.*repmat(log_w_hat(france),[FI_J 1]);

% Technology shocks and elasticities expanded to the (N*J x N) price grid
a_hat_repmat   = repmat(a_hat,[1 N]);
rho_mnj_repmat = repmat(rho,[N N]);

% Firm-level elasticities and the share-independent markup weight
% w = (rho-1)/rho * mu0, where mu0 depends only on the initial shares
rho_f        = firmsecD_sorted*rho;
repmat_sigma = repmat(firmsecD_sorted * sigma,[1,N]);
repmat_rho   = repmat(firmsecD_sorted * rho,[1,N]);
mu_mnjf0     = repmat_sigma.*repmat_rho./(repmat_sigma.*(repmat_rho-1)-(repmat_rho-repmat_sigma).*pi_mnjf0);
mu_weight    = (repmat_rho-1)./repmat_rho.*mu_mnjf0;

% Rows of the price matrix that belong to France's sectors
french_rows = start(france):start(france+1)-1;

fprintf('   PRICE ITERATION=%d \n',it);

while (it <= maxit) && (pfmax > tol)
   %% Sector-level operations: input-bundle cost, then bilateral prices
   b_hat = sector_cost_change(P_hat_knj, pi_M, pi_l, log_w_hat_wlambda, J);

   P_hat_mnj = (pi_mnj.*Xi_hat_mnj.*(repmat(b_hat,[1 N]).*a_hat_repmat)...
      .^(1-rho_mnj_repmat)).^(1./(1-rho_mnj_repmat));

   %% French firm-level operations
   % Firm-specific labour/import shares and shocks enter the cost function
   % directly, so French costs are rebuilt from firm-level data.
   b_hat_f = firm_cost_change(P_hat_knj(:,france), pi_M_f, pi_l_f, ...
      log_w_hat_f_wlambda, FI_J);

   % Markups respond to the change in firm market shares implied by the
   % previous markup guess and the current firm costs
   mu_hat_mnjf = markup_change(mu_hat_mnjf, b_hat_f, a_hat_f, Xi_hat_mnjf, ...
      pi_mnjf0, rho_f, firmsecD_sorted, mu_weight, N);

   % Replace France's sector prices with the firm-based aggregate
   P_hat_mnj(french_rows,:) = french_price_change(mu_hat_mnjf, b_hat_f, ...
      a_hat_f, Xi_hat_mnjf, pi_mnjf0, rho_f, rho, firmsecD_sorted, N);

   %% Ending price loop
   % Should not be necessary: non-finite prices (possible when a zero share
   % is raised to a negative power) point to a problem in earlier files.
   P_hat_mnj(~isfinite(P_hat_mnj))=1;

   % Relative deviation from the guess; pfmax is its Euclidean norm
   P_hat_mnj_vec = reshape(P_hat_mnj,[N*N*J 1]);
   P_hat_knj_vec = reshape(P_hat_knj,[N*N*J 1]);
   pfdev = abs(P_hat_knj_vec-P_hat_mnj_vec)./P_hat_knj_vec;
   pfmax = norm(pfdev);

   % Update the price guess
   P_hat_knj = P_hat_mnj;

   it = it + 1;
   fprintf('   PRICE ITERATION=%d        Max price distance=%d  \n',it-1,pfmax);
end

% NaN > tol is false, so a NaN distance would otherwise end the loop as if
% it had converged; likewise hitting maxit used to be silent.
if isnan(pfmax)
   warning('PriceLoopCES_fun_approx_Olig:nanDistance', ...
      'Price distance became NaN after %d iterations; prices are not a converged solution.', it-1);
elseif pfmax > tol
   warning('PriceLoopCES_fun_approx_Olig:noConvergence', ...
      'Price loop stopped after %d iterations without converging (distance=%g, tol=%g).', ...
      it-1, pfmax, tol);
end

%% Generate final output at the last price guess
P_hat_mnj = P_hat_knj;

[b_hat, log_P_M_hat] = sector_cost_change(P_hat_knj, pi_M, pi_l, log_w_hat_wlambda, J);
P_M_hat = exp(log_P_M_hat);
P_M_hat(isinf(P_M_hat)) = 1;

[b_hat_f, log_P_M_f_hat] = firm_cost_change(P_hat_knj(:,france), pi_M_f, pi_l_f, ...
   log_w_hat_f_wlambda, FI_J);
P_M_f_hat = exp(log_P_M_f_hat);

mu_hat_mnjf = markup_change(mu_hat_mnjf, b_hat_f, a_hat_f, Xi_hat_mnjf, ...
   pi_mnjf0, rho_f, firmsecD_sorted, mu_weight, N);

end


function [b_hat, log_P_M_hat] = sector_cost_change(P_hat, pi_M, pi_l, log_w_hat_wlambda, J)
% Country-sector input-bundle cost change b_hat (N*J x 1) and the log change
% in the intermediate price index, log_P_M_hat (1 x N*J).
% log P^M_hat is the pi_M-weighted sum over rows of log price changes;
% log b_hat = pi_l*log(w_hat) + (1-pi_l)*log(P^M_hat).
log_P_M_hat = squeeze(sum(pi_M.*kron(log(P_hat),ones(1,J)),1));
b_hat = exp(log_w_hat_wlambda + (log_P_M_hat.*(1-pi_l)')');
end


function [b_hat_f, log_P_M_f_hat] = firm_cost_change(P_hat_france, pi_M_f, pi_l_f, log_w_hat_f_wlambda, FI_J)
% French firms' input-bundle cost change b_hat_f (FI_J x 1) and the log
% change in their intermediate price index (FI_J x 1). P_hat_france is the
% column `france` of the price-change matrix.
log_P_M_f_hat = sum(pi_M_f.*kron(log(P_hat_france'),ones(FI_J,1)),2);
b_hat_f = exp(log_w_hat_f_wlambda + log_P_M_f_hat.*(1-pi_l_f));
end


function mu_hat_mnjf = markup_change(mu_hat_prev, b_hat_f, a_hat_f, Xi_hat_mnjf, pi_mnjf0, rho_f, firmsecD_sorted, mu_weight, N)
% Update of French firms' markup change (FI_J x N). Firm market shares are
% recomputed from the previous markup guess and current costs, normalised
% within each sector via firmsecD_sorted. Their change pi_hat (kept at 1
% where the initial share is zero) enters mu_hat = 1/(w + (1-w)*pi_hat).
pi_xi_brackets = Xi_hat_mnjf.*(mu_hat_prev.*repmat(b_hat_f,[1 N]).*repmat(a_hat_f,[1 N])).^(1-rho_f).*pi_mnjf0;
pi_mnjf1 = pi_xi_brackets./(firmsecD_sorted*((firmsecD_sorted')*pi_xi_brackets));

pi_hat_mnjf = ones(size(pi_mnjf0));
has_share = pi_mnjf0~=0;
pi_hat_mnjf(has_share) = pi_mnjf1(has_share)./pi_mnjf0(has_share);

mu_hat_mnjf = 1./(mu_weight+(1-mu_weight).*pi_hat_mnjf);
end


function P_hat_france = french_price_change(mu_hat_mnjf, b_hat_f, a_hat_f, Xi_hat_mnjf, pi_mnjf0, rho_f, rho, firmsecD_sorted, N)
% French sector price change (one row per sector, N columns): within each
% sector, the share-weighted sum over firms of
% Xi_hat*(mu_hat*b_hat_f*a_hat_f)^(1-rho), raised to 1/(1-rho) of that sector.
sum_pi_xi_brackets = (firmsecD_sorted')*(pi_mnjf0.*...
      Xi_hat_mnjf.*(mu_hat_mnjf.*repmat(b_hat_f,[1 N]).*repmat(a_hat_f,[1 N])).^(1-repmat(rho_f,[1 N])));
P_hat_france = sum_pi_xi_brackets.^(1./(1-rho(:)));
end